# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# MIT License

# Copyright (c) 2024 Ending Hsiao

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np

# https://github.com/andrewwillmott/sh-lib
#
# Precomputed square roots of the rational coefficients that appear in the
# band-to-band SH rotation recurrences below.
# Naming convention: kSqrtNN_DD == sqrt(NN / DD). Every value must match its
# name exactly; a single wrong constant silently corrupts the matrices.

kSqrt02_01 = np.sqrt(2.0 / 1.0)
kSqrt01_02 = np.sqrt(1.0 / 2.0)
kSqrt03_02 = np.sqrt(3.0 / 2.0)
kSqrt01_03 = np.sqrt(1.0 / 3.0)
kSqrt02_03 = np.sqrt(2.0 / 3.0)
kSqrt04_03 = np.sqrt(4.0 / 3.0)
kSqrt01_04 = np.sqrt(1.0 / 4.0)
kSqrt03_04 = np.sqrt(3.0 / 4.0)
kSqrt05_04 = np.sqrt(5.0 / 4.0)
kSqrt01_05 = np.sqrt(1.0 / 5.0)
kSqrt02_05 = np.sqrt(2.0 / 5.0)
kSqrt03_05 = np.sqrt(3.0 / 5.0)
kSqrt04_05 = np.sqrt(4.0 / 5.0)
kSqrt06_05 = np.sqrt(6.0 / 5.0)
kSqrt08_05 = np.sqrt(8.0 / 5.0)
kSqrt09_05 = np.sqrt(9.0 / 5.0)
kSqrt01_06 = np.sqrt(1.0 / 6.0)
kSqrt05_06 = np.sqrt(5.0 / 6.0)
kSqrt07_06 = np.sqrt(7.0 / 6.0)
kSqrt02_07 = np.sqrt(2.0 / 7.0)
kSqrt06_07 = np.sqrt(6.0 / 7.0)
kSqrt10_07 = np.sqrt(10.0 / 7.0)
kSqrt12_07 = np.sqrt(12.0 / 7.0)
kSqrt15_07 = np.sqrt(15.0 / 7.0)
kSqrt16_07 = np.sqrt(16.0 / 7.0)
kSqrt01_08 = np.sqrt(1.0 / 8.0)
kSqrt03_08 = np.sqrt(3.0 / 8.0)
kSqrt05_08 = np.sqrt(5.0 / 8.0)
kSqrt07_08 = np.sqrt(7.0 / 8.0)
kSqrt09_08 = np.sqrt(9.0 / 8.0)
kSqrt05_09 = np.sqrt(5.0 / 9.0)
kSqrt08_09 = np.sqrt(8.0 / 9.0)
kSqrt01_10 = np.sqrt(1.0 / 10.0)
kSqrt03_10 = np.sqrt(3.0 / 10.0)
kSqrt07_10 = np.sqrt(7.0 / 10.0)
kSqrt09_10 = np.sqrt(9.0 / 10.0)
kSqrt01_12 = np.sqrt(1.0 / 12.0)
kSqrt07_12 = np.sqrt(7.0 / 12.0)
kSqrt11_12 = np.sqrt(11.0 / 12.0)
kSqrt01_14 = np.sqrt(1.0 / 14.0)
kSqrt03_14 = np.sqrt(3.0 / 14.0)
kSqrt15_14 = np.sqrt(15.0 / 14.0)
kSqrt04_15 = np.sqrt(4.0 / 15.0)
# sqrt(7/15): band-4 row coefficients run sqrt(7/12), sqrt(7/15), sqrt(7/16).
kSqrt07_15 = np.sqrt(7.0 / 15.0)
kSqrt14_15 = np.sqrt(14.0 / 15.0)
kSqrt16_15 = np.sqrt(16.0 / 15.0)
kSqrt01_16 = np.sqrt(1.0 / 16.0)
kSqrt03_16 = np.sqrt(3.0 / 16.0)
kSqrt07_16 = np.sqrt(7.0 / 16.0)
kSqrt15_16 = np.sqrt(15.0 / 16.0)
kSqrt01_18 = np.sqrt(1.0 / 18.0)
kSqrt01_24 = np.sqrt(1.0 / 24.0)
kSqrt03_25 = np.sqrt(3.0 / 25.0)
kSqrt09_25 = np.sqrt(9.0 / 25.0)
kSqrt14_25 = np.sqrt(14.0 / 25.0)
kSqrt16_25 = np.sqrt(16.0 / 25.0)
kSqrt18_25 = np.sqrt(18.0 / 25.0)
kSqrt21_25 = np.sqrt(21.0 / 25.0)
kSqrt24_25 = np.sqrt(24.0 / 25.0)
kSqrt03_28 = np.sqrt(3.0 / 28.0)
kSqrt05_28 = np.sqrt(5.0 / 28.0)
kSqrt01_30 = np.sqrt(1.0 / 30.0)
kSqrt01_32 = np.sqrt(1.0 / 32.0)
kSqrt03_32 = np.sqrt(3.0 / 32.0)
kSqrt15_32 = np.sqrt(15.0 / 32.0)
kSqrt21_32 = np.sqrt(21.0 / 32.0)
kSqrt11_36 = np.sqrt(11.0 / 36.0)
kSqrt35_36 = np.sqrt(35.0 / 36.0)
kSqrt01_50 = np.sqrt(1.0 / 50.0)
kSqrt03_50 = np.sqrt(3.0 / 50.0)
kSqrt21_50 = np.sqrt(21.0 / 50.0)
kSqrt15_56 = np.sqrt(15.0 / 56.0)
kSqrt01_60 = np.sqrt(1.0 / 60.0)
kSqrt01_112 = np.sqrt(1.0 / 112.0)
kSqrt03_112 = np.sqrt(3.0 / 112.0)
kSqrt15_112 = np.sqrt(15.0 / 112.0)


def get_sh1(R):
    """Return the band-1 SH rotation matrix for the rotation matrix ``R``.

    The last two axes of ``R`` are both cyclically shifted by one place, so
    ``out[..., i, j] == R[..., (i + 1) % 3, (j + 1) % 3]`` for a 3x3 input.
    If ``R``'s axes are ordered (x, y, z), the result is ordered (y, z, x).
    Leading batch axes, if any, are left untouched.
    """
    # One multi-axis roll is equivalent to rolling each axis in turn.
    return np.roll(R, shift=(-1, -1), axis=(-2, -1))


def get_sh2(sh1):
    """Build the 5x5 band-2 SH rotation matrix from the band-1 matrix ``sh1``.

    Every entry is a fixed combination of products of two ``sh1`` entries,
    scaled by the module-level ``kSqrt*`` constants, following the recurrence
    used by andrewwillmott/sh-lib.

    Args:
        sh1: 3x3 band-1 rotation matrix, presumably the output of ``get_sh1``
            for a single (unbatched) rotation; it is indexed as ``sh1[i][j]``
            and each entry is written into a scalar slot of the result.

    Returns:
        A new (5, 5) array with the same dtype as ``sh1``.
    """
    # Output is a single, unbatched 5x5 matrix; every entry is assigned below.
    sh2 = np.zeros((5, 5), dtype=sh1.dtype)

    sh2[0][0] = kSqrt01_04 * (
        (sh1[2][2] * sh1[0][0] + sh1[2][0] * sh1[0][2])
        + (sh1[0][2] * sh1[2][0] + sh1[0][0] * sh1[2][2])
    )
    sh2[0][1] = sh1[2][1] * sh1[0][0] + sh1[0][1] * sh1[2][0]
    sh2[0][2] = kSqrt03_04 * (sh1[2][1] * sh1[0][1] + sh1[0][1] * sh1[2][1])
    sh2[0][3] = sh1[2][1] * sh1[0][2] + sh1[0][1] * sh1[2][2]
    sh2[0][4] = kSqrt01_04 * (
        (sh1[2][2] * sh1[0][2] - sh1[2][0] * sh1[0][0])
        + (sh1[0][2] * sh1[2][2] - sh1[0][0] * sh1[2][0])
    )

    sh2[1][0] = kSqrt01_04 * (
        (sh1[1][2] * sh1[0][0] + sh1[1][0] * sh1[0][2])
        + (sh1[0][2] * sh1[1][0] + sh1[0][0] * sh1[1][2])
    )
    sh2[1][1] = sh1[1][1] * sh1[0][0] + sh1[0][1] * sh1[1][0]
    sh2[1][2] = kSqrt03_04 * (sh1[1][1] * sh1[0][1] + sh1[0][1] * sh1[1][1])
    sh2[1][3] = sh1[1][1] * sh1[0][2] + sh1[0][1] * sh1[1][2]
    sh2[1][4] = kSqrt01_04 * (
        (sh1[1][2] * sh1[0][2] - sh1[1][0] * sh1[0][0])
        + (sh1[0][2] * sh1[1][2] - sh1[0][0] * sh1[1][0])
    )

    sh2[2][0] = kSqrt01_03 * (
        sh1[1][2] * sh1[1][0] + sh1[1][0] * sh1[1][2]
    ) - kSqrt01_12 * (
        (sh1[2][2] * sh1[2][0] + sh1[2][0] * sh1[2][2])
        + (sh1[0][2] * sh1[0][0] + sh1[0][0] * sh1[0][2])
    )
    sh2[2][1] = kSqrt04_03 * sh1[1][1] * sh1[1][0] - kSqrt01_03 * (
        sh1[2][1] * sh1[2][0] + sh1[0][1] * sh1[0][0]
    )
    sh2[2][2] = sh1[1][1] * sh1[1][1] - kSqrt01_04 * (
        sh1[2][1] * sh1[2][1] + sh1[0][1] * sh1[0][1]
    )
    sh2[2][3] = kSqrt04_03 * sh1[1][1] * sh1[1][2] - kSqrt01_03 * (
        sh1[2][1] * sh1[2][2] + sh1[0][1] * sh1[0][2]
    )
    sh2[2][4] = kSqrt01_03 * (
        sh1[1][2] * sh1[1][2] - sh1[1][0] * sh1[1][0]
    ) - kSqrt01_12 * (
        (sh1[2][2] * sh1[2][2] - sh1[2][0] * sh1[2][0])
        + (sh1[0][2] * sh1[0][2] - sh1[0][0] * sh1[0][0])
    )

    sh2[3][0] = kSqrt01_04 * (
        (sh1[1][2] * sh1[2][0] + sh1[1][0] * sh1[2][2])
        + (sh1[2][2] * sh1[1][0] + sh1[2][0] * sh1[1][2])
    )
    sh2[3][1] = sh1[1][1] * sh1[2][0] + sh1[2][1] * sh1[1][0]
    sh2[3][2] = kSqrt03_04 * (sh1[1][1] * sh1[2][1] + sh1[2][1] * sh1[1][1])
    sh2[3][3] = sh1[1][1] * sh1[2][2] + sh1[2][1] * sh1[1][2]
    sh2[3][4] = kSqrt01_04 * (
        (sh1[1][2] * sh1[2][2] - sh1[1][0] * sh1[2][0])
        + (sh1[2][2] * sh1[1][2] - sh1[2][0] * sh1[1][0])
    )

    sh2[4][0] = kSqrt01_04 * (
        (sh1[2][2] * sh1[2][0] + sh1[2][0] * sh1[2][2])
        - (sh1[0][2] * sh1[0][0] + sh1[0][0] * sh1[0][2])
    )
    sh2[4][1] = sh1[2][1] * sh1[2][0] - sh1[0][1] * sh1[0][0]
    sh2[4][2] = kSqrt03_04 * (sh1[2][1] * sh1[2][1] - sh1[0][1] * sh1[0][1])
    sh2[4][3] = sh1[2][1] * sh1[2][2] - sh1[0][1] * sh1[0][2]
    sh2[4][4] = kSqrt01_04 * (
        (sh1[2][2] * sh1[2][2] - sh1[2][0] * sh1[2][0])
        - (sh1[0][2] * sh1[0][2] - sh1[0][0] * sh1[0][0])
    )

    return sh2


def get_sh3(sh1, sh2):
    """Build the 7x7 band-3 SH rotation matrix from the band-1 and band-2 matrices.

    Every entry combines products of one ``sh1`` entry with one ``sh2`` entry,
    scaled by the module-level ``kSqrt*`` constants, following the recurrence
    used by andrewwillmott/sh-lib.

    Args:
        sh1: 3x3 band-1 matrix (presumably from ``get_sh1``, unbatched).
        sh2: 5x5 band-2 matrix (presumably ``get_sh2(sh1)``).

    Returns:
        A new (7, 7) array with the same dtype as ``sh1``.
    """
    # Output is a single, unbatched 7x7 matrix; every entry is assigned below.
    sh3 = np.zeros((7, 7), dtype=sh1.dtype)

    sh3[0][0] = kSqrt01_04 * (
        (sh1[2][2] * sh2[0][0] + sh1[2][0] * sh2[0][4])
        + (sh1[0][2] * sh2[4][0] + sh1[0][0] * sh2[4][4])
    )
    sh3[0][1] = kSqrt03_02 * (sh1[2][1] * sh2[0][0] + sh1[0][1] * sh2[4][0])
    sh3[0][2] = kSqrt15_16 * (sh1[2][1] * sh2[0][1] + sh1[0][1] * sh2[4][1])
    sh3[0][3] = kSqrt05_06 * (sh1[2][1] * sh2[0][2] + sh1[0][1] * sh2[4][2])
    sh3[0][4] = kSqrt15_16 * (sh1[2][1] * sh2[0][3] + sh1[0][1] * sh2[4][3])
    sh3[0][5] = kSqrt03_02 * (sh1[2][1] * sh2[0][4] + sh1[0][1] * sh2[4][4])
    sh3[0][6] = kSqrt01_04 * (
        (sh1[2][2] * sh2[0][4] - sh1[2][0] * sh2[0][0])
        + (sh1[0][2] * sh2[4][4] - sh1[0][0] * sh2[4][0])
    )

    sh3[1][0] = kSqrt01_06 * (
        sh1[1][2] * sh2[0][0] + sh1[1][0] * sh2[0][4]
    ) + kSqrt01_06 * (
        (sh1[2][2] * sh2[1][0] + sh1[2][0] * sh2[1][4])
        + (sh1[0][2] * sh2[3][0] + sh1[0][0] * sh2[3][4])
    )
    sh3[1][1] = sh1[1][1] * sh2[0][0] + (sh1[2][1] * sh2[1][0] + sh1[0][1] * sh2[3][0])
    sh3[1][2] = kSqrt05_08 * sh1[1][1] * sh2[0][1] + kSqrt05_08 * (
        sh1[2][1] * sh2[1][1] + sh1[0][1] * sh2[3][1]
    )
    sh3[1][3] = kSqrt05_09 * sh1[1][1] * sh2[0][2] + kSqrt05_09 * (
        sh1[2][1] * sh2[1][2] + sh1[0][1] * sh2[3][2]
    )
    sh3[1][4] = kSqrt05_08 * sh1[1][1] * sh2[0][3] + kSqrt05_08 * (
        sh1[2][1] * sh2[1][3] + sh1[0][1] * sh2[3][3]
    )
    sh3[1][5] = sh1[1][1] * sh2[0][4] + (sh1[2][1] * sh2[1][4] + sh1[0][1] * sh2[3][4])
    sh3[1][6] = kSqrt01_06 * (
        sh1[1][2] * sh2[0][4] - sh1[1][0] * sh2[0][0]
    ) + kSqrt01_06 * (
        (sh1[2][2] * sh2[1][4] - sh1[2][0] * sh2[1][0])
        + (sh1[0][2] * sh2[3][4] - sh1[0][0] * sh2[3][0])
    )

    sh3[2][0] = (
        kSqrt04_15 * (sh1[1][2] * sh2[1][0] + sh1[1][0] * sh2[1][4])
        + kSqrt01_05 * (sh1[0][2] * sh2[2][0] + sh1[0][0] * sh2[2][4])
        - kSqrt01_60
        * (
            (sh1[2][2] * sh2[0][0] + sh1[2][0] * sh2[0][4])
            - (sh1[0][2] * sh2[4][0] + sh1[0][0] * sh2[4][4])
        )
    )
    sh3[2][1] = (
        kSqrt08_05 * sh1[1][1] * sh2[1][0]
        + kSqrt06_05 * sh1[0][1] * sh2[2][0]
        - kSqrt01_10 * (sh1[2][1] * sh2[0][0] - sh1[0][1] * sh2[4][0])
    )
    sh3[2][2] = (
        sh1[1][1] * sh2[1][1]
        + kSqrt03_04 * sh1[0][1] * sh2[2][1]
        - kSqrt01_16 * (sh1[2][1] * sh2[0][1] - sh1[0][1] * sh2[4][1])
    )
    sh3[2][3] = (
        kSqrt08_09 * sh1[1][1] * sh2[1][2]
        + kSqrt02_03 * sh1[0][1] * sh2[2][2]
        - kSqrt01_18 * (sh1[2][1] * sh2[0][2] - sh1[0][1] * sh2[4][2])
    )
    sh3[2][4] = (
        sh1[1][1] * sh2[1][3]
        + kSqrt03_04 * sh1[0][1] * sh2[2][3]
        - kSqrt01_16 * (sh1[2][1] * sh2[0][3] - sh1[0][1] * sh2[4][3])
    )
    sh3[2][5] = (
        kSqrt08_05 * sh1[1][1] * sh2[1][4]
        + kSqrt06_05 * sh1[0][1] * sh2[2][4]
        - kSqrt01_10 * (sh1[2][1] * sh2[0][4] - sh1[0][1] * sh2[4][4])
    )
    sh3[2][6] = (
        kSqrt04_15 * (sh1[1][2] * sh2[1][4] - sh1[1][0] * sh2[1][0])
        + kSqrt01_05 * (sh1[0][2] * sh2[2][4] - sh1[0][0] * sh2[2][0])
        - kSqrt01_60
        * (
            (sh1[2][2] * sh2[0][4] - sh1[2][0] * sh2[0][0])
            - (sh1[0][2] * sh2[4][4] - sh1[0][0] * sh2[4][0])
        )
    )

    sh3[3][0] = kSqrt03_10 * (
        sh1[1][2] * sh2[2][0] + sh1[1][0] * sh2[2][4]
    ) - kSqrt01_10 * (
        (sh1[2][2] * sh2[3][0] + sh1[2][0] * sh2[3][4])
        + (sh1[0][2] * sh2[1][0] + sh1[0][0] * sh2[1][4])
    )
    sh3[3][1] = kSqrt09_05 * sh1[1][1] * sh2[2][0] - kSqrt03_05 * (
        sh1[2][1] * sh2[3][0] + sh1[0][1] * sh2[1][0]
    )
    sh3[3][2] = kSqrt09_08 * sh1[1][1] * sh2[2][1] - kSqrt03_08 * (
        sh1[2][1] * sh2[3][1] + sh1[0][1] * sh2[1][1]
    )
    sh3[3][3] = sh1[1][1] * sh2[2][2] - kSqrt01_03 * (
        sh1[2][1] * sh2[3][2] + sh1[0][1] * sh2[1][2]
    )
    sh3[3][4] = kSqrt09_08 * sh1[1][1] * sh2[2][3] - kSqrt03_08 * (
        sh1[2][1] * sh2[3][3] + sh1[0][1] * sh2[1][3]
    )
    sh3[3][5] = kSqrt09_05 * sh1[1][1] * sh2[2][4] - kSqrt03_05 * (
        sh1[2][1] * sh2[3][4] + sh1[0][1] * sh2[1][4]
    )
    sh3[3][6] = kSqrt03_10 * (
        sh1[1][2] * sh2[2][4] - sh1[1][0] * sh2[2][0]
    ) - kSqrt01_10 * (
        (sh1[2][2] * sh2[3][4] - sh1[2][0] * sh2[3][0])
        + (sh1[0][2] * sh2[1][4] - sh1[0][0] * sh2[1][0])
    )

    sh3[4][0] = (
        kSqrt04_15 * (sh1[1][2] * sh2[3][0] + sh1[1][0] * sh2[3][4])
        + kSqrt01_05 * (sh1[2][2] * sh2[2][0] + sh1[2][0] * sh2[2][4])
        - kSqrt01_60
        * (
            (sh1[2][2] * sh2[4][0] + sh1[2][0] * sh2[4][4])
            + (sh1[0][2] * sh2[0][0] + sh1[0][0] * sh2[0][4])
        )
    )
    sh3[4][1] = (
        kSqrt08_05 * sh1[1][1] * sh2[3][0]
        + kSqrt06_05 * sh1[2][1] * sh2[2][0]
        - kSqrt01_10 * (sh1[2][1] * sh2[4][0] + sh1[0][1] * sh2[0][0])
    )
    sh3[4][2] = (
        sh1[1][1] * sh2[3][1]
        + kSqrt03_04 * sh1[2][1] * sh2[2][1]
        - kSqrt01_16 * (sh1[2][1] * sh2[4][1] + sh1[0][1] * sh2[0][1])
    )
    sh3[4][3] = (
        kSqrt08_09 * sh1[1][1] * sh2[3][2]
        + kSqrt02_03 * sh1[2][1] * sh2[2][2]
        - kSqrt01_18 * (sh1[2][1] * sh2[4][2] + sh1[0][1] * sh2[0][2])
    )
    sh3[4][4] = (
        sh1[1][1] * sh2[3][3]
        + kSqrt03_04 * sh1[2][1] * sh2[2][3]
        - kSqrt01_16 * (sh1[2][1] * sh2[4][3] + sh1[0][1] * sh2[0][3])
    )
    sh3[4][5] = (
        kSqrt08_05 * sh1[1][1] * sh2[3][4]
        + kSqrt06_05 * sh1[2][1] * sh2[2][4]
        - kSqrt01_10 * (sh1[2][1] * sh2[4][4] + sh1[0][1] * sh2[0][4])
    )
    sh3[4][6] = (
        kSqrt04_15 * (sh1[1][2] * sh2[3][4] - sh1[1][0] * sh2[3][0])
        + kSqrt01_05 * (sh1[2][2] * sh2[2][4] - sh1[2][0] * sh2[2][0])
        - kSqrt01_60
        * (
            (sh1[2][2] * sh2[4][4] - sh1[2][0] * sh2[4][0])
            + (sh1[0][2] * sh2[0][4] - sh1[0][0] * sh2[0][0])
        )
    )

    sh3[5][0] = kSqrt01_06 * (
        sh1[1][2] * sh2[4][0] + sh1[1][0] * sh2[4][4]
    ) + kSqrt01_06 * (
        (sh1[2][2] * sh2[3][0] + sh1[2][0] * sh2[3][4])
        - (sh1[0][2] * sh2[1][0] + sh1[0][0] * sh2[1][4])
    )
    sh3[5][1] = sh1[1][1] * sh2[4][0] + (sh1[2][1] * sh2[3][0] - sh1[0][1] * sh2[1][0])
    sh3[5][2] = kSqrt05_08 * sh1[1][1] * sh2[4][1] + kSqrt05_08 * (
        sh1[2][1] * sh2[3][1] - sh1[0][1] * sh2[1][1]
    )
    sh3[5][3] = kSqrt05_09 * sh1[1][1] * sh2[4][2] + kSqrt05_09 * (
        sh1[2][1] * sh2[3][2] - sh1[0][1] * sh2[1][2]
    )
    sh3[5][4] = kSqrt05_08 * sh1[1][1] * sh2[4][3] + kSqrt05_08 * (
        sh1[2][1] * sh2[3][3] - sh1[0][1] * sh2[1][3]
    )
    sh3[5][5] = sh1[1][1] * sh2[4][4] + (sh1[2][1] * sh2[3][4] - sh1[0][1] * sh2[1][4])
    sh3[5][6] = kSqrt01_06 * (
        sh1[1][2] * sh2[4][4] - sh1[1][0] * sh2[4][0]
    ) + kSqrt01_06 * (
        (sh1[2][2] * sh2[3][4] - sh1[2][0] * sh2[3][0])
        - (sh1[0][2] * sh2[1][4] - sh1[0][0] * sh2[1][0])
    )

    sh3[6][0] = kSqrt01_04 * (
        (sh1[2][2] * sh2[4][0] + sh1[2][0] * sh2[4][4])
        - (sh1[0][2] * sh2[0][0] + sh1[0][0] * sh2[0][4])
    )
    sh3[6][1] = kSqrt03_02 * (sh1[2][1] * sh2[4][0] - sh1[0][1] * sh2[0][0])
    sh3[6][2] = kSqrt15_16 * (sh1[2][1] * sh2[4][1] - sh1[0][1] * sh2[0][1])
    sh3[6][3] = kSqrt05_06 * (sh1[2][1] * sh2[4][2] - sh1[0][1] * sh2[0][2])
    sh3[6][4] = kSqrt15_16 * (sh1[2][1] * sh2[4][3] - sh1[0][1] * sh2[0][3])
    sh3[6][5] = kSqrt03_02 * (sh1[2][1] * sh2[4][4] - sh1[0][1] * sh2[0][4])
    sh3[6][6] = kSqrt01_04 * (
        (sh1[2][2] * sh2[4][4] - sh1[2][0] * sh2[4][0])
        - (sh1[0][2] * sh2[0][4] - sh1[0][0] * sh2[0][0])
    )

    return sh3


def get_sh4(sh1, sh2, sh3):
    """Build the 9x9 band-4 SH rotation matrix from the lower-band matrices.

    Every entry combines products of one ``sh1`` entry with one ``sh3`` entry,
    scaled by the module-level ``kSqrt*`` constants, following the recurrence
    used by andrewwillmott/sh-lib. ``sh2`` is accepted for signature symmetry
    with the other builders but is not read here.

    Args:
        sh1: 3x3 band-1 matrix (presumably from ``get_sh1``, unbatched).
        sh2: band-2 matrix; unused by this function.
        sh3: 7x7 band-3 matrix (presumably ``get_sh3(sh1, sh2)``).

    Returns:
        A new (9, 9) array with the same dtype as ``sh1``.

    NOTE(review): results are only as correct as the ``kSqrt*`` constants.
    For example, rows 1 and 7 use the sequence 1, sqrt(7/12), sqrt(7/15),
    sqrt(7/16), so ``kSqrt07_15`` must equal sqrt(7/15). A cheap sanity
    check is that the result is presumably orthogonal for a proper rotation
    (``sh4 @ sh4.T`` close to identity).
    """

    # Output is a single, unbatched 9x9 matrix; every entry is assigned below.
    sh4 = np.zeros((9, 9), dtype=sh1.dtype)

    sh4[0][0] = kSqrt01_04 * (
        (sh1[2][2] * sh3[0][0] + sh1[2][0] * sh3[0][6])
        + (sh1[0][2] * sh3[6][0] + sh1[0][0] * sh3[6][6])
    )
    sh4[0][1] = kSqrt02_01 * (sh1[2][1] * sh3[0][0] + sh1[0][1] * sh3[6][0])
    sh4[0][2] = kSqrt07_06 * (sh1[2][1] * sh3[0][1] + sh1[0][1] * sh3[6][1])
    sh4[0][3] = kSqrt14_15 * (sh1[2][1] * sh3[0][2] + sh1[0][1] * sh3[6][2])
    sh4[0][4] = kSqrt07_08 * (sh1[2][1] * sh3[0][3] + sh1[0][1] * sh3[6][3])
    sh4[0][5] = kSqrt14_15 * (sh1[2][1] * sh3[0][4] + sh1[0][1] * sh3[6][4])
    sh4[0][6] = kSqrt07_06 * (sh1[2][1] * sh3[0][5] + sh1[0][1] * sh3[6][5])
    sh4[0][7] = kSqrt02_01 * (sh1[2][1] * sh3[0][6] + sh1[0][1] * sh3[6][6])
    sh4[0][8] = kSqrt01_04 * (
        (sh1[2][2] * sh3[0][6] - sh1[2][0] * sh3[0][0])
        + (sh1[0][2] * sh3[6][6] - sh1[0][0] * sh3[6][0])
    )

    sh4[1][0] = kSqrt01_08 * (
        sh1[1][2] * sh3[0][0] + sh1[1][0] * sh3[0][6]
    ) + kSqrt03_16 * (
        (sh1[2][2] * sh3[1][0] + sh1[2][0] * sh3[1][6])
        + (sh1[0][2] * sh3[5][0] + sh1[0][0] * sh3[5][6])
    )
    sh4[1][1] = sh1[1][1] * sh3[0][0] + kSqrt03_02 * (
        sh1[2][1] * sh3[1][0] + sh1[0][1] * sh3[5][0]
    )
    sh4[1][2] = kSqrt07_12 * sh1[1][1] * sh3[0][1] + kSqrt07_08 * (
        sh1[2][1] * sh3[1][1] + sh1[0][1] * sh3[5][1]
    )
    sh4[1][3] = kSqrt07_15 * sh1[1][1] * sh3[0][2] + kSqrt07_10 * (
        sh1[2][1] * sh3[1][2] + sh1[0][1] * sh3[5][2]
    )
    sh4[1][4] = kSqrt07_16 * sh1[1][1] * sh3[0][3] + kSqrt21_32 * (
        sh1[2][1] * sh3[1][3] + sh1[0][1] * sh3[5][3]
    )
    sh4[1][5] = kSqrt07_15 * sh1[1][1] * sh3[0][4] + kSqrt07_10 * (
        sh1[2][1] * sh3[1][4] + sh1[0][1] * sh3[5][4]
    )
    sh4[1][6] = kSqrt07_12 * sh1[1][1] * sh3[0][5] + kSqrt07_08 * (
        sh1[2][1] * sh3[1][5] + sh1[0][1] * sh3[5][5]
    )
    sh4[1][7] = sh1[1][1] * sh3[0][6] + kSqrt03_02 * (
        sh1[2][1] * sh3[1][6] + sh1[0][1] * sh3[5][6]
    )
    sh4[1][8] = kSqrt01_08 * (
        sh1[1][2] * sh3[0][6] - sh1[1][0] * sh3[0][0]
    ) + kSqrt03_16 * (
        (sh1[2][2] * sh3[1][6] - sh1[2][0] * sh3[1][0])
        + (sh1[0][2] * sh3[5][6] - sh1[0][0] * sh3[5][0])
    )

    sh4[2][0] = (
        kSqrt03_14 * (sh1[1][2] * sh3[1][0] + sh1[1][0] * sh3[1][6])
        + kSqrt15_112
        * (
            (sh1[2][2] * sh3[2][0] + sh1[2][0] * sh3[2][6])
            + (sh1[0][2] * sh3[4][0] + sh1[0][0] * sh3[4][6])
        )
        - kSqrt01_112
        * (
            (sh1[2][2] * sh3[0][0] + sh1[2][0] * sh3[0][6])
            - (sh1[0][2] * sh3[6][0] + sh1[0][0] * sh3[6][6])
        )
    )
    sh4[2][1] = (
        kSqrt12_07 * sh1[1][1] * sh3[1][0]
        + kSqrt15_14 * (sh1[2][1] * sh3[2][0] + sh1[0][1] * sh3[4][0])
        - kSqrt01_14 * (sh1[2][1] * sh3[0][0] - sh1[0][1] * sh3[6][0])
    )
    sh4[2][2] = (
        sh1[1][1] * sh3[1][1]
        + kSqrt05_08 * (sh1[2][1] * sh3[2][1] + sh1[0][1] * sh3[4][1])
        - kSqrt01_24 * (sh1[2][1] * sh3[0][1] - sh1[0][1] * sh3[6][1])
    )
    sh4[2][3] = (
        kSqrt04_05 * sh1[1][1] * sh3[1][2]
        + kSqrt01_02 * (sh1[2][1] * sh3[2][2] + sh1[0][1] * sh3[4][2])
        - kSqrt01_30 * (sh1[2][1] * sh3[0][2] - sh1[0][1] * sh3[6][2])
    )
    sh4[2][4] = (
        kSqrt03_04 * sh1[1][1] * sh3[1][3]
        + kSqrt15_32 * (sh1[2][1] * sh3[2][3] + sh1[0][1] * sh3[4][3])
        - kSqrt01_32 * (sh1[2][1] * sh3[0][3] - sh1[0][1] * sh3[6][3])
    )
    sh4[2][5] = (
        kSqrt04_05 * sh1[1][1] * sh3[1][4]
        + kSqrt01_02 * (sh1[2][1] * sh3[2][4] + sh1[0][1] * sh3[4][4])
        - kSqrt01_30 * (sh1[2][1] * sh3[0][4] - sh1[0][1] * sh3[6][4])
    )
    sh4[2][6] = (
        sh1[1][1] * sh3[1][5]
        + kSqrt05_08 * (sh1[2][1] * sh3[2][5] + sh1[0][1] * sh3[4][5])
        - kSqrt01_24 * (sh1[2][1] * sh3[0][5] - sh1[0][1] * sh3[6][5])
    )
    sh4[2][7] = (
        kSqrt12_07 * sh1[1][1] * sh3[1][6]
        + kSqrt15_14 * (sh1[2][1] * sh3[2][6] + sh1[0][1] * sh3[4][6])
        - kSqrt01_14 * (sh1[2][1] * sh3[0][6] - sh1[0][1] * sh3[6][6])
    )
    sh4[2][8] = (
        kSqrt03_14 * (sh1[1][2] * sh3[1][6] - sh1[1][0] * sh3[1][0])
        + kSqrt15_112
        * (
            (sh1[2][2] * sh3[2][6] - sh1[2][0] * sh3[2][0])
            + (sh1[0][2] * sh3[4][6] - sh1[0][0] * sh3[4][0])
        )
        - kSqrt01_112
        * (
            (sh1[2][2] * sh3[0][6] - sh1[2][0] * sh3[0][0])
            - (sh1[0][2] * sh3[6][6] - sh1[0][0] * sh3[6][0])
        )
    )

    sh4[3][0] = (
        kSqrt15_56 * (sh1[1][2] * sh3[2][0] + sh1[1][0] * sh3[2][6])
        + kSqrt05_28 * (sh1[0][2] * sh3[3][0] + sh1[0][0] * sh3[3][6])
        - kSqrt03_112
        * (
            (sh1[2][2] * sh3[1][0] + sh1[2][0] * sh3[1][6])
            - (sh1[0][2] * sh3[5][0] + sh1[0][0] * sh3[5][6])
        )
    )
    sh4[3][1] = (
        kSqrt15_07 * sh1[1][1] * sh3[2][0]
        + kSqrt10_07 * sh1[0][1] * sh3[3][0]
        - kSqrt03_14 * (sh1[2][1] * sh3[1][0] - sh1[0][1] * sh3[5][0])
    )
    sh4[3][2] = (
        kSqrt05_04 * sh1[1][1] * sh3[2][1]
        + kSqrt05_06 * sh1[0][1] * sh3[3][1]
        - kSqrt01_08 * (sh1[2][1] * sh3[1][1] - sh1[0][1] * sh3[5][1])
    )
    sh4[3][3] = (
        sh1[1][1] * sh3[2][2]
        + kSqrt02_03 * sh1[0][1] * sh3[3][2]
        - kSqrt01_10 * (sh1[2][1] * sh3[1][2] - sh1[0][1] * sh3[5][2])
    )
    sh4[3][4] = (
        kSqrt15_16 * sh1[1][1] * sh3[2][3]
        + kSqrt05_08 * sh1[0][1] * sh3[3][3]
        - kSqrt03_32 * (sh1[2][1] * sh3[1][3] - sh1[0][1] * sh3[5][3])
    )
    sh4[3][5] = (
        sh1[1][1] * sh3[2][4]
        + kSqrt02_03 * sh1[0][1] * sh3[3][4]
        - kSqrt01_10 * (sh1[2][1] * sh3[1][4] - sh1[0][1] * sh3[5][4])
    )
    sh4[3][6] = (
        kSqrt05_04 * sh1[1][1] * sh3[2][5]
        + kSqrt05_06 * sh1[0][1] * sh3[3][5]
        - kSqrt01_08 * (sh1[2][1] * sh3[1][5] - sh1[0][1] * sh3[5][5])
    )
    sh4[3][7] = (
        kSqrt15_07 * sh1[1][1] * sh3[2][6]
        + kSqrt10_07 * sh1[0][1] * sh3[3][6]
        - kSqrt03_14 * (sh1[2][1] * sh3[1][6] - sh1[0][1] * sh3[5][6])
    )
    sh4[3][8] = (
        kSqrt15_56 * (sh1[1][2] * sh3[2][6] - sh1[1][0] * sh3[2][0])
        + kSqrt05_28 * (sh1[0][2] * sh3[3][6] - sh1[0][0] * sh3[3][0])
        - kSqrt03_112
        * (
            (sh1[2][2] * sh3[1][6] - sh1[2][0] * sh3[1][0])
            - (sh1[0][2] * sh3[5][6] - sh1[0][0] * sh3[5][0])
        )
    )

    sh4[4][0] = kSqrt02_07 * (
        sh1[1][2] * sh3[3][0] + sh1[1][0] * sh3[3][6]
    ) - kSqrt03_28 * (
        (sh1[2][2] * sh3[4][0] + sh1[2][0] * sh3[4][6])
        + (sh1[0][2] * sh3[2][0] + sh1[0][0] * sh3[2][6])
    )
    sh4[4][1] = kSqrt16_07 * sh1[1][1] * sh3[3][0] - kSqrt06_07 * (
        sh1[2][1] * sh3[4][0] + sh1[0][1] * sh3[2][0]
    )
    sh4[4][2] = kSqrt04_03 * sh1[1][1] * sh3[3][1] - kSqrt01_02 * (
        sh1[2][1] * sh3[4][1] + sh1[0][1] * sh3[2][1]
    )
    sh4[4][3] = kSqrt16_15 * sh1[1][1] * sh3[3][2] - kSqrt02_05 * (
        sh1[2][1] * sh3[4][2] + sh1[0][1] * sh3[2][2]
    )
    sh4[4][4] = sh1[1][1] * sh3[3][3] - kSqrt03_08 * (
        sh1[2][1] * sh3[4][3] + sh1[0][1] * sh3[2][3]
    )
    sh4[4][5] = kSqrt16_15 * sh1[1][1] * sh3[3][4] - kSqrt02_05 * (
        sh1[2][1] * sh3[4][4] + sh1[0][1] * sh3[2][4]
    )
    sh4[4][6] = kSqrt04_03 * sh1[1][1] * sh3[3][5] - kSqrt01_02 * (
        sh1[2][1] * sh3[4][5] + sh1[0][1] * sh3[2][5]
    )
    sh4[4][7] = kSqrt16_07 * sh1[1][1] * sh3[3][6] - kSqrt06_07 * (
        sh1[2][1] * sh3[4][6] + sh1[0][1] * sh3[2][6]
    )
    sh4[4][8] = kSqrt02_07 * (
        sh1[1][2] * sh3[3][6] - sh1[1][0] * sh3[3][0]
    ) - kSqrt03_28 * (
        (sh1[2][2] * sh3[4][6] - sh1[2][0] * sh3[4][0])
        + (sh1[0][2] * sh3[2][6] - sh1[0][0] * sh3[2][0])
    )

    sh4[5][0] = (
        kSqrt15_56 * (sh1[1][2] * sh3[4][0] + sh1[1][0] * sh3[4][6])
        + kSqrt05_28 * (sh1[2][2] * sh3[3][0] + sh1[2][0] * sh3[3][6])
        - kSqrt03_112
        * (
            (sh1[2][2] * sh3[5][0] + sh1[2][0] * sh3[5][6])
            + (sh1[0][2] * sh3[1][0] + sh1[0][0] * sh3[1][6])
        )
    )
    sh4[5][1] = (
        kSqrt15_07 * sh1[1][1] * sh3[4][0]
        + kSqrt10_07 * sh1[2][1] * sh3[3][0]
        - kSqrt03_14 * (sh1[2][1] * sh3[5][0] + sh1[0][1] * sh3[1][0])
    )
    sh4[5][2] = (
        kSqrt05_04 * sh1[1][1] * sh3[4][1]
        + kSqrt05_06 * sh1[2][1] * sh3[3][1]
        - kSqrt01_08 * (sh1[2][1] * sh3[5][1] + sh1[0][1] * sh3[1][1])
    )
    sh4[5][3] = (
        sh1[1][1] * sh3[4][2]
        + kSqrt02_03 * sh1[2][1] * sh3[3][2]
        - kSqrt01_10 * (sh1[2][1] * sh3[5][2] + sh1[0][1] * sh3[1][2])
    )
    sh4[5][4] = (
        kSqrt15_16 * sh1[1][1] * sh3[4][3]
        + kSqrt05_08 * sh1[2][1] * sh3[3][3]
        - kSqrt03_32 * (sh1[2][1] * sh3[5][3] + sh1[0][1] * sh3[1][3])
    )
    sh4[5][5] = (
        sh1[1][1] * sh3[4][4]
        + kSqrt02_03 * sh1[2][1] * sh3[3][4]
        - kSqrt01_10 * (sh1[2][1] * sh3[5][4] + sh1[0][1] * sh3[1][4])
    )
    sh4[5][6] = (
        kSqrt05_04 * sh1[1][1] * sh3[4][5]
        + kSqrt05_06 * sh1[2][1] * sh3[3][5]
        - kSqrt01_08 * (sh1[2][1] * sh3[5][5] + sh1[0][1] * sh3[1][5])
    )
    sh4[5][7] = (
        kSqrt15_07 * sh1[1][1] * sh3[4][6]
        + kSqrt10_07 * sh1[2][1] * sh3[3][6]
        - kSqrt03_14 * (sh1[2][1] * sh3[5][6] + sh1[0][1] * sh3[1][6])
    )
    sh4[5][8] = (
        kSqrt15_56 * (sh1[1][2] * sh3[4][6] - sh1[1][0] * sh3[4][0])
        + kSqrt05_28 * (sh1[2][2] * sh3[3][6] - sh1[2][0] * sh3[3][0])
        - kSqrt03_112
        * (
            (sh1[2][2] * sh3[5][6] - sh1[2][0] * sh3[5][0])
            + (sh1[0][2] * sh3[1][6] - sh1[0][0] * sh3[1][0])
        )
    )

    sh4[6][0] = (
        kSqrt03_14 * (sh1[1][2] * sh3[5][0] + sh1[1][0] * sh3[5][6])
        + kSqrt15_112
        * (
            (sh1[2][2] * sh3[4][0] + sh1[2][0] * sh3[4][6])
            - (sh1[0][2] * sh3[2][0] + sh1[0][0] * sh3[2][6])
        )
        - kSqrt01_112
        * (
            (sh1[2][2] * sh3[6][0] + sh1[2][0] * sh3[6][6])
            + (sh1[0][2] * sh3[0][0] + sh1[0][0] * sh3[0][6])
        )
    )
    sh4[6][1] = (
        kSqrt12_07 * sh1[1][1] * sh3[5][0]
        + kSqrt15_14 * (sh1[2][1] * sh3[4][0] - sh1[0][1] * sh3[2][0])
        - kSqrt01_14 * (sh1[2][1] * sh3[6][0] + sh1[0][1] * sh3[0][0])
    )
    sh4[6][2] = (
        sh1[1][1] * sh3[5][1]
        + kSqrt05_08 * (sh1[2][1] * sh3[4][1] - sh1[0][1] * sh3[2][1])
        - kSqrt01_24 * (sh1[2][1] * sh3[6][1] + sh1[0][1] * sh3[0][1])
    )
    sh4[6][3] = (
        kSqrt04_05 * sh1[1][1] * sh3[5][2]
        + kSqrt01_02 * (sh1[2][1] * sh3[4][2] - sh1[0][1] * sh3[2][2])
        - kSqrt01_30 * (sh1[2][1] * sh3[6][2] + sh1[0][1] * sh3[0][2])
    )
    sh4[6][4] = (
        kSqrt03_04 * sh1[1][1] * sh3[5][3]
        + kSqrt15_32 * (sh1[2][1] * sh3[4][3] - sh1[0][1] * sh3[2][3])
        - kSqrt01_32 * (sh1[2][1] * sh3[6][3] + sh1[0][1] * sh3[0][3])
    )
    sh4[6][5] = (
        kSqrt04_05 * sh1[1][1] * sh3[5][4]
        + kSqrt01_02 * (sh1[2][1] * sh3[4][4] - sh1[0][1] * sh3[2][4])
        - kSqrt01_30 * (sh1[2][1] * sh3[6][4] + sh1[0][1] * sh3[0][4])
    )
    sh4[6][6] = (
        sh1[1][1] * sh3[5][5]
        + kSqrt05_08 * (sh1[2][1] * sh3[4][5] - sh1[0][1] * sh3[2][5])
        - kSqrt01_24 * (sh1[2][1] * sh3[6][5] + sh1[0][1] * sh3[0][5])
    )
    sh4[6][7] = (
        kSqrt12_07 * sh1[1][1] * sh3[5][6]
        + kSqrt15_14 * (sh1[2][1] * sh3[4][6] - sh1[0][1] * sh3[2][6])
        - kSqrt01_14 * (sh1[2][1] * sh3[6][6] + sh1[0][1] * sh3[0][6])
    )
    sh4[6][8] = (
        kSqrt03_14 * (sh1[1][2] * sh3[5][6] - sh1[1][0] * sh3[5][0])
        + kSqrt15_112
        * (
            (sh1[2][2] * sh3[4][6] - sh1[2][0] * sh3[4][0])
            - (sh1[0][2] * sh3[2][6] - sh1[0][0] * sh3[2][0])
        )
        - kSqrt01_112
        * (
            (sh1[2][2] * sh3[6][6] - sh1[2][0] * sh3[6][0])
            + (sh1[0][2] * sh3[0][6] - sh1[0][0] * sh3[0][0])
        )
    )

    sh4[7][0] = kSqrt01_08 * (
        sh1[1][2] * sh3[6][0] + sh1[1][0] * sh3[6][6]
    ) + kSqrt03_16 * (
        (sh1[2][2] * sh3[5][0] + sh1[2][0] * sh3[5][6])
        - (sh1[0][2] * sh3[1][0] + sh1[0][0] * sh3[1][6])
    )
    sh4[7][1] = sh1[1][1] * sh3[6][0] + kSqrt03_02 * (
        sh1[2][1] * sh3[5][0] - sh1[0][1] * sh3[1][0]
    )
    sh4[7][2] = kSqrt07_12 * sh1[1][1] * sh3[6][1] + kSqrt07_08 * (
        sh1[2][1] * sh3[5][1] - sh1[0][1] * sh3[1][1]
    )
    sh4[7][3] = kSqrt07_15 * sh1[1][1] * sh3[6][2] + kSqrt07_10 * (
        sh1[2][1] * sh3[5][2] - sh1[0][1] * sh3[1][2]
    )
    sh4[7][4] = kSqrt07_16 * sh1[1][1] * sh3[6][3] + kSqrt21_32 * (
        sh1[2][1] * sh3[5][3] - sh1[0][1] * sh3[1][3]
    )
    sh4[7][5] = kSqrt07_15 * sh1[1][1] * sh3[6][4] + kSqrt07_10 * (
        sh1[2][1] * sh3[5][4] - sh1[0][1] * sh3[1][4]
    )
    sh4[7][6] = kSqrt07_12 * sh1[1][1] * sh3[6][5] + kSqrt07_08 * (
        sh1[2][1] * sh3[5][5] - sh1[0][1] * sh3[1][5]
    )
    sh4[7][7] = sh1[1][1] * sh3[6][6] + kSqrt03_02 * (
        sh1[2][1] * sh3[5][6] - sh1[0][1] * sh3[1][6]
    )
    sh4[7][8] = kSqrt01_08 * (
        sh1[1][2] * sh3[6][6] - sh1[1][0] * sh3[6][0]
    ) + kSqrt03_16 * (
        (sh1[2][2] * sh3[5][6] - sh1[2][0] * sh3[5][0])
        - (sh1[0][2] * sh3[1][6] - sh1[0][0] * sh3[1][0])
    )

    sh4[8][0] = kSqrt01_04 * (
        (sh1[2][2] * sh3[6][0] + sh1[2][0] * sh3[6][6])
        - (sh1[0][2] * sh3[0][0] + sh1[0][0] * sh3[0][6])
    )
    sh4[8][1] = kSqrt02_01 * (sh1[2][1] * sh3[6][0] - sh1[0][1] * sh3[0][0])
    sh4[8][2] = kSqrt07_06 * (sh1[2][1] * sh3[6][1] - sh1[0][1] * sh3[0][1])
    sh4[8][3] = kSqrt14_15 * (sh1[2][1] * sh3[6][2] - sh1[0][1] * sh3[0][2])
    sh4[8][4] = kSqrt07_08 * (sh1[2][1] * sh3[6][3] - sh1[0][1] * sh3[0][3])
    sh4[8][5] = kSqrt14_15 * (sh1[2][1] * sh3[6][4] - sh1[0][1] * sh3[0][4])
    sh4[8][6] = kSqrt07_06 * (sh1[2][1] * sh3[6][5] - sh1[0][1] * sh3[0][5])
    sh4[8][7] = kSqrt02_01 * (sh1[2][1] * sh3[6][6] - sh1[0][1] * sh3[0][6])
    sh4[8][8] = kSqrt01_04 * (
        (sh1[2][2] * sh3[6][6] - sh1[2][0] * sh3[6][0])
        - (sh1[0][2] * sh3[0][6] - sh1[0][0] * sh3[0][0])
    )

    return sh4


class SHRotator:
    """Rotate real spherical-harmonic (SH) coefficient arrays by a fixed rotation.

    The per-band rotation matrices ``sh1`` ... ``sh<deg>`` are built once from
    ``R`` in ``__init__`` and then applied by calling the instance.

    Args:
        R: rotation matrix handed to ``get_sh1`` (presumably 3x3 -- confirm).
        deg: highest SH band to rotate, 0 through 4. Band 0 is copied
            through unchanged.

    Raises:
        NotImplementedError: if ``deg > 4`` (no band-5+ builder exists).
    """

    # Highest band for which this module has a rotation-matrix builder.
    _MAX_DEG = 4

    def __init__(self, R, deg=3):
        # Reject unsupported degrees before doing any (wasted) matrix work.
        if deg > self._MAX_DEG:
            raise NotImplementedError
        self.deg = deg
        if deg > 0:
            self.sh1 = get_sh1(R)
        if deg > 1:
            self.sh2 = get_sh2(self.sh1)
        if deg > 2:
            self.sh3 = get_sh3(self.sh1, self.sh2)
        if deg > 3:
            self.sh4 = get_sh4(self.sh1, self.sh2, self.sh3)

    def __call__(self, shs_in):
        """Return a rotated copy of ``shs_in``.

        Args:
            shs_in: array of shape ``(..., (deg + 1) ** 2)``. The last axis
                holds the SH coefficients grouped by band: index 0 is band 0,
                1:4 band 1, 4:9 band 2, 9:16 band 3, 16:25 band 4. Any number
                of leading batch axes (including none) is supported.

        Returns:
            Array of the same shape as ``shs_in``.

        Raises:
            ValueError: if the last axis does not hold ``(deg + 1) ** 2``
                coefficients.
        """
        n_coeffs = (self.deg + 1) ** 2
        if shs_in.shape[-1] != n_coeffs:
            raise ValueError(
                f"expected {n_coeffs} SH coefficients on the last axis for "
                f"deg={self.deg}, got shape {shs_in.shape}"
            )

        # Band 0 is rotation invariant.
        shs_out = [shs_in[..., 0:1]]
        start = 1
        for band in range(1, self.deg + 1):
            stop = start + 2 * band + 1
            mat = getattr(self, f"sh{band}")
            # Row-vector form: x @ M.T equals (M @ x.T).T for 2-D x, but also
            # broadcasts over any leading batch axes, whereas ``.T`` reverses
            # *all* axes and breaks inputs with ndim > 2.
            shs_out.append(shs_in[..., start:stop] @ mat.T)
            start = stop
        return np.concatenate(shs_out, axis=-1)
